# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""expected output of SharedExpertMLP"""
import numpy as np
import mindspore as ms
from mindformers.parallel_core.transformer_config import TransformerConfig


def get_init_params(config: TransformerConfig, seq_length=4, batch_size=2):
    """Build the deterministic input tensor and weight dict for the test.

    NumPy's global RNG is seeded with 1, so the weights and the input are
    reproducible. The draw order (fc1, fc2, gate, input) determines the
    values and must not be changed.

    Args:
        config: Supplies `hidden_size` and `ffn_hidden_size` for the shapes.
        seq_length: First dimension of the generated input.
        batch_size: Second dimension of the generated input.

    Returns:
        A tuple `(input_, state_dict)`: a bfloat16 tensor of shape
        `(seq_length, batch_size, config.hidden_size)` and a dict mapping
        weight names to float32 `ms.Parameter` objects.
    """
    np.random.seed(1)
    hidden = config.hidden_size
    ffn_hidden = config.ffn_hidden_size

    # Listed in draw order: each entry consumes random numbers in turn.
    weight_shapes = (
        ("mlp.linear_fc1.weight", (ffn_hidden, hidden)),
        ("mlp.linear_fc2.weight", (hidden, ffn_hidden)),
        ("mlp.shared_experts_gate.weight", (1, hidden)),
    )
    # Draw every weight first, then wrap, so the RNG stream is consumed
    # before any tensor conversion happens.
    raw_weights = {name: 0.01 * np.random.rand(*shape) for name, shape in weight_shapes}
    state_dict = {
        name: ms.Parameter(ms.tensor(values, dtype=ms.float32))
        for name, values in raw_weights.items()
    }

    # The input is drawn last, after all weights.
    raw_input = np.random.rand(seq_length, batch_size, hidden)
    input_ = ms.tensor(raw_input, dtype=ms.bfloat16)

    return input_, state_dict


def get_golden_datas(args) -> np.ndarray:
    """Return the hard-coded golden output used as the test reference.

    Args:
        args: Any object exposing a boolean-like `gate` attribute. A falsy
            `gate` selects the first table, a truthy one the second.
            NOTE(review): presumably `gate` toggles the shared-experts gate
            of the module under test -- confirm against the caller.

    Returns:
        A float32 array of shape (4, 2, 16). Despite the previous
        annotation, this is a single array, not a dict.
    """
    if not args.gate:
        output_golden = np.array(
            [[[0.0019919416, 0.0020969177, 0.0019119657, 0.0016831061,
               0.0015097546, 0.0024931964, 0.0022334943, 0.0015607530,
               0.0016224522, 0.0018226021, 0.0013644423, 0.0020328050,
               0.0016648017, 0.0019004948, 0.0018370659, 0.0017043886],
              [0.0015751260, 0.0016880741, 0.0015236203, 0.0013121577,
               0.0011900065, 0.0019737678, 0.0017540415, 0.0012518927,
               0.0012730283, 0.0014253005, 0.0010885255, 0.0016146263,
               0.0013117085, 0.0015241898, 0.0014786809, 0.0013727419]],
             [[0.0019514252, 0.0020915624, 0.0018961922, 0.0016189103,
               0.0014609485, 0.0024405473, 0.0021871955, 0.0015791889,
               0.0015751384, 0.0017856713, 0.0013288374, 0.0019960622,
               0.0016252893, 0.0018935392, 0.0017961639, 0.0017028205],
              [0.0018867148, 0.0020168365, 0.0018601861, 0.0015550752,
               0.0014348353, 0.0022915448, 0.0021151179, 0.0015194945,
               0.0015332238, 0.0017510453, 0.0013417912, 0.0019695892,
               0.0015768812, 0.0018295512, 0.0017305955, 0.0016419239]],
             [[0.0014538650, 0.0015827805, 0.0014461200, 0.0012378941,
               0.0010918338, 0.0018326158, 0.0016430839, 0.0011372042,
               0.0012296749, 0.0013730322, 0.0010569921, 0.0015125085,
               0.0011989862, 0.0014621870, 0.0013702807, 0.0012903372],
              [0.0017733796, 0.0019430587, 0.0017720656, 0.0014609765,
               0.0013720200, 0.0021825042, 0.0020178643, 0.0014793403,
               0.0014748116, 0.0016822204, 0.0012621277, 0.0018532595,
               0.0014871625, 0.0017542860, 0.0016103230, 0.0015683112]],
             [[0.0018832930, 0.0020078733, 0.0018468696, 0.0015462689,
               0.0014222845, 0.0022924296, 0.0021196704, 0.0015214478,
               0.0015399234, 0.0017200317, 0.0013087805, 0.0019487045,
               0.0015966668, 0.0018407278, 0.0016958695, 0.0016332783],
              [0.0013505130, 0.0014500835, 0.0013241110, 0.0011225898,
               0.0010145908, 0.0016514058, 0.0015427083, 0.0010856654,
               0.0011017679, 0.0012138544, 0.0009260966, 0.0014120181,
               0.0011524413, 0.0013224644, 0.0012130215, 0.0011502546]]], dtype=np.float32)
    else:
        output_golden = np.array(
            [[[0.0010165554, 0.0010701282, 0.0009757409, 0.0008589461,
               0.0007704789, 0.0012723627, 0.0011398279, 0.0007965051,
               0.0008279923, 0.0009301357, 0.0006963211, 0.0010374093,
               0.0008496048, 0.0009698869, 0.0009375170, 0.0008698073],
              [0.0008018822, 0.0008593830, 0.0007756611, 0.0006680075,
               0.0006058214, 0.0010048271, 0.0008929665, 0.0006373271,
               0.0006480871, 0.0007256074, 0.0005541584, 0.0008219914,
               0.0006677788, 0.0007759511, 0.0007527829, 0.0006988503]],
             [[0.0009986131, 0.0010703262, 0.0009703484, 0.0008284535,
               0.0007476188, 0.0012489142, 0.0011192651, 0.0008081266,
               0.0008060539, 0.0009137910, 0.0006800129, 0.0010214554,
               0.0008317179, 0.0009689908, 0.0009191604, 0.0008713933],
              [0.0009630753, 0.0010294961, 0.0009495337, 0.0007937895,
               0.0007324130, 0.0011697211, 0.0010796639, 0.0007756273,
               0.0007826355, 0.0008938227, 0.0006849185, 0.0010053786,
               0.0008049204, 0.0009338961, 0.0008833840, 0.0008381215]],
             [[0.0007348188, 0.0007999759, 0.0007309043, 0.0006256619,
               0.0005518394, 0.0009262487, 0.0008304546, 0.0005747707,
               0.0006215077, 0.0006939640, 0.0005342296, 0.0007644587,
               0.0006059969, 0.0007390250, 0.0006925733, 0.0006521679],
              [0.0009029080, 0.0009892994, 0.0009022391, 0.0007438494,
               0.0006985576, 0.0011112120, 0.0010273863, 0.0007531993,
               0.0007508935, 0.0008564947, 0.0006426065, 0.0009435785,
               0.0007571820, 0.0008931866, 0.0008198885, 0.0007984984]],
             [[0.0009632797, 0.0010270008, 0.0009446496, 0.0007908963,
               0.0007274799, 0.0011725477, 0.0010841836, 0.0007782006,
               0.0007876506, 0.0008797737, 0.0006694241, 0.0009967368,
               0.0008166741, 0.0009415081, 0.0008674150, 0.0008354004],
              [0.0006846662, 0.0007351452, 0.0006712813, 0.0005691166,
               0.0005143646, 0.0008372092, 0.0007821030, 0.0005503971,
               0.0005585605, 0.0006153848, 0.0004695009, 0.0007158474,
               0.0005842503, 0.0006704465, 0.0006149625, 0.0005831418]]], dtype=np.float32)

    return output_golden


def get_gpu_datas(args) -> np.ndarray:
    """Return the hard-coded GPU output used as a comparison baseline.

    Args:
        args: Any object exposing a boolean-like `gate` attribute. A falsy
            `gate` selects the first table, a truthy one the second.
            NOTE(review): presumably `gate` toggles the shared-experts gate
            of the module under test -- confirm against the caller.

    Returns:
        A float16 array of shape (4, 2, 16). Despite the previous
        annotation, this is a single array, not a dict.
    """
    if not args.gate:
        output_gpu = np.array(
            [[[0.0019989014, 0.0020904541, 0.0019149780, 0.0016860962,
               0.0015106201, 0.0024871826, 0.0022277832, 0.0015563965,
               0.0016174316, 0.0018234253, 0.0013656616, 0.0020294189,
               0.0016632080, 0.0018997192, 0.0018386841, 0.0017013550],
              [0.0015716553, 0.0016860962, 0.0015258789, 0.0013122559,
               0.0011901855, 0.0019683838, 0.0017547607, 0.0012512207,
               0.0012741089, 0.0014266968, 0.0010910034, 0.0016174316,
               0.0013122559, 0.0015258789, 0.0014801025, 0.0013732910]],
             [[0.0019531250, 0.0020904541, 0.0018997192, 0.0016174316,
               0.0014572144, 0.0024414062, 0.0021820068, 0.0015792847,
               0.0015792847, 0.0017852783, 0.0013275146, 0.0019989014,
               0.0016250610, 0.0018920898, 0.0017929077, 0.0017013550],
              [0.0018844604, 0.0020141602, 0.0018615723, 0.0015563965,
               0.0014343262, 0.0022888184, 0.0021209717, 0.0015182495,
               0.0015335083, 0.0017471313, 0.0013427734, 0.0019683838,
               0.0015716553, 0.0018310547, 0.0017318726, 0.0016403198]],
             [[0.0014572144, 0.0015869141, 0.0014495850, 0.0012359619,
               0.0010910034, 0.0018310547, 0.0016479492, 0.0011367798,
               0.0012283325, 0.0013732910, 0.0010604858, 0.0015106201,
               0.0011978149, 0.0014648438, 0.0013732910, 0.0012893677],
              [0.0017776489, 0.0019454956, 0.0017700195, 0.0014648438,
               0.0013732910, 0.0021820068, 0.0020141602, 0.0014801025,
               0.0014724731, 0.0016784668, 0.0012588501, 0.0018539429,
               0.0014877319, 0.0017547607, 0.0016098022, 0.0015640259]],
             [[0.0018844604, 0.0020141602, 0.0018463135, 0.0015487671,
               0.0014190674, 0.0022888184, 0.0021209717, 0.0015182495,
               0.0015411377, 0.0017166138, 0.0013046265, 0.0019454956,
               0.0015945435, 0.0018386841, 0.0016937256, 0.0016326904],
              [0.0013504028, 0.0014495850, 0.0013275146, 0.0011215210,
               0.0010147095, 0.0016479492, 0.0015411377, 0.0010833740,
               0.0011062622, 0.0012130737, 0.0009269714, 0.0014114380,
               0.0011520386, 0.0013275146, 0.0012130737, 0.0011520386]]], dtype=np.float16)
    else:
        output_gpu = np.array(
            [[[0.0010223389, 0.0010681152, 0.0009765625, 0.0008621216,
               0.0007743835, 0.0012741089, 0.0011367798, 0.0007972717,
               0.0008277893, 0.0009346008, 0.0006980896, 0.0010375977,
               0.0008506775, 0.0009727478, 0.0009422302, 0.0008697510],
              [0.0007972717, 0.0008544922, 0.0007743835, 0.0006675720,
               0.0006027222, 0.0009994507, 0.0008926392, 0.0006370544,
               0.0006484985, 0.0007247925, 0.0005531311, 0.0008201599,
               0.0006675720, 0.0007743835, 0.0007514954, 0.0006980896]],
             [[0.0009994507, 0.0010681152, 0.0009727478, 0.0008277893,
               0.0007438660, 0.0012512207, 0.0011138916, 0.0008087158,
               0.0008087158, 0.0009117126, 0.0006790161, 0.0010223389,
               0.0008316040, 0.0009689331, 0.0009193420, 0.0008697510],
              [0.0009651184, 0.0010299683, 0.0009536743, 0.0007972717,
               0.0007324219, 0.0011749268, 0.0010833740, 0.0007781982,
               0.0007858276, 0.0008926392, 0.0006866455, 0.0010070801,
               0.0008049011, 0.0009384155, 0.0008850098, 0.0008392334]],
             [[0.0007324219, 0.0008010864, 0.0007286072, 0.0006217957,
               0.0005493164, 0.0009231567, 0.0008316040, 0.0005722046,
               0.0006179810, 0.0006904602, 0.0005340576, 0.0007629395,
               0.0006027222, 0.0007400513, 0.0006904602, 0.0006484985],
              [0.0009040833, 0.0009841919, 0.0009002686, 0.0007438660,
               0.0006980896, 0.0011062622, 0.0010223389, 0.0007514954,
               0.0007476807, 0.0008506775, 0.0006408691, 0.0009422302,
               0.0007553101, 0.0008926392, 0.0008163452, 0.0007934570]],
             [[0.0009651184, 0.0010299683, 0.0009460449, 0.0007934570,
               0.0007247925, 0.0011749268, 0.0010833740, 0.0007781982,
               0.0007896423, 0.0008773804, 0.0006675720, 0.0009918213,
               0.0008163452, 0.0009422302, 0.0008659363, 0.0008354187],
              [0.0006866455, 0.0007362366, 0.0006752014, 0.0005683899,
               0.0005149841, 0.0008354187, 0.0007820129, 0.0005493164,
               0.0005607605, 0.0006141663, 0.0004711151, 0.0007171631,
               0.0005836487, 0.0006752014, 0.0006141663, 0.0005836487]]], dtype=np.float16)

    return output_gpu
